# -*- coding: utf-8 -*-
"""
Created on Wed Feb 5 20:46:37 2025

@author: wxie
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from numpy.fft import fft2, fftshift, ifftshift
from sklearn import linear_model
from tqdm import tqdm
import time

def zernike_z1toz36(z, r, u): # rms convention
    """Evaluate a wavefront as the weighted sum of the first 36 Zernike terms.

    Every term is RMS-normalised: the radial polynomial is scaled by
    sqrt(n+1) when m == 0 and by sqrt(2(n+1)) when m != 0. With these basis
    functions the RMS of the wavefront over the unit disk is the
    root-sum-square of the non-piston coefficients.

    The terms are grouped by (n + m)/2, as in the Fringe ordering:
    Z1 piston; Z2-Z4 tilt and defocus; Z5-Z9; Z10-Z16; Z17-Z25; Z26-Z36.
    NOTE(review): within the pairs, Z5/Z6 and Z7/Z8 list the sin term first,
    while Z10 onward list the cos term first. Changing the order would change
    the meaning of stored coefficients (e.g. those read from WFE500.csv), so
    confirm how the training data was generated before touching it.

    Parameters
    ----------
    z : sequence of float
        Coefficients; z[0] .. z[35] weight Z1 .. Z36 (36 entries are indexed).
    r : array_like
        Normalised radial coordinate (the unit pupil is r <= 1).
    u : array_like
        Azimuthal angle in radians, broadcastable with ``r``.

    Returns
    -------
    ndarray
        Sum of the 36 weighted terms, with the broadcast shape of ``r`` and
        ``u``, in the units of ``z`` (this file treats them as waves).
    """
    # Z1: piston. cos^2 + sin^2 (1 up to rounding) only serves to give this
    # term the array shape of u.
    z1 = z[0] * 1*(np.cos(u)**2+np.sin(u)**2)
    # Z2-Z3: tilt (n=1, m=1); Z4: defocus (n=2, m=0)
    z2 = z[1] * 2*r*np.cos(u)
    z3 = z[2] * 2*r*np.sin(u)
    z4 = z[3] * np.sqrt(3)*(2*r**2-1)
    
    # Z5-Z6: astigmatism (n=2, m=2); Z7-Z8: coma (n=3, m=1);
    # Z9: spherical (n=4, m=0)
    z5 = z[4] * np.sqrt(6)*r**2*np.sin(2*u)
    z6 = z[5] * np.sqrt(6)*r**2*np.cos(2*u)
    z7 = z[6] * np.sqrt(8)*(3*r**2-2)*r*np.sin(u)
    z8 = z[7] * np.sqrt(8)*(3*r**2-2)*r*np.cos(u)
    z9 = z[8] * np.sqrt(5)*(6*r**4-6*r**2+1)
    
    # Z10-Z11: trefoil (n=3, m=3); Z12-Z13: secondary astigmatism (n=4, m=2);
    # Z14-Z15: secondary coma (n=5, m=1); Z16: secondary spherical (n=6, m=0)
    z10 = z[9] * np.sqrt(8)*r**3*np.cos(3*u)
    z11 = z[10] * np.sqrt(8)*r**3*np.sin(3*u)
    z12 = z[11] * np.sqrt(10)*(4*r**4-3*r**2)*np.cos(2*u)
    z13 = z[12] * np.sqrt(10)*(4*r**4 - 3*r**2)*np.sin(2*u)
    z14 = z[13] * np.sqrt(12)*(10*r**5 - 12*r**3 + 3*r)*np.cos(u)
    z15 = z[14] * np.sqrt(12)*(10*r**5 - 12*r**3 + 3*r)*np.sin(u)
    z16 = z[15] * np.sqrt(7)*(20*r**6 - 30*r**4 + 12*r**2 - 1)
    
    # Z17-Z18: n=4, m=4; Z19-Z20: n=5, m=3; Z21-Z22: n=6, m=2;
    # Z23-Z24: n=7, m=1; Z25: n=8, m=0
    z17 = z[16] * np.sqrt(10)*r**4*np.cos(4*u)
    z18 = z[17] * np.sqrt(10)*r**4*np.sin(4*u)
    z19 = z[18] * np.sqrt(12)*(5*r**5 - 4*r**3)*np.cos(3*u)
    z20 = z[19] * np.sqrt(12)*(5*r**5 - 4*r**3)*np.sin(3*u)
    z21 = z[20] * np.sqrt(14)*(15*r**6 - 20*r**4 + 6*r**2)*np.cos(2*u)
    z22 = z[21] * np.sqrt(14)*(15*r**6 - 20*r**4 + 6*r**2)*np.sin(2*u)
    z23 = z[22] * np.sqrt(16)*(35*r**7 - 60*r**5 + 30*r**3 - 4*r)*np.cos(u)
    z24 = z[23] * np.sqrt(16)*(35*r**7 - 60*r**5 + 30*r**3 - 4*r)*np.sin(u)
    z25 = z[24] * np.sqrt(9)*(70*r**8 - 140*r**6 + 90*r**4 - 20*r**2 + 1)

    # Z26-Z27: n=5, m=5; Z28-Z29: n=6, m=4; Z30-Z31: n=7, m=3;
    # Z32-Z33: n=8, m=2; Z34-Z35: n=9, m=1; Z36: n=10, m=0
    z26 = z[25] * np.sqrt(12)*r**5*np.cos(5*u)
    z27 = z[26] * np.sqrt(12)*r**5*np.sin(5*u)
    z28 = z[27] * np.sqrt(14)*(6*r**6 - 5*r**4)*np.cos(4*u)
    z29 = z[28] * np.sqrt(14)*(6*r**6 - 5*r**4)*np.sin(4*u)
    z30 = z[29] * np.sqrt(16)*(21*r**7 - 30*r**5 + 10*r**3)*np.cos(3*u)
    z31 = z[30] * np.sqrt(16)*(21*r**7 - 30*r**5 + 10*r**3)*np.sin(3*u)
    z32 = z[31] * np.sqrt(18)*(56*r**8 - 105*r**6 + 60*r**4 - 10*r**2)*np.cos(2*u)
    z33 = z[32] * np.sqrt(18)*(56*r**8 - 105*r**6 + 60*r**4 - 10*r**2)*np.sin(2*u)
    z34 = z[33] * np.sqrt(20)*(126*r**9 - 280*r**7 + 210*r**5 - 60*r**3 + 5*r)*np.cos(u)
    z35 = z[34] * np.sqrt(20)*(126*r**9 - 280*r**7 + 210*r**5 - 60*r**3 + 5*r)*np.sin(u)
    z36 = z[35] * np.sqrt(11)*(252*r**10 - 630*r**8 + 560*r**6 - 210*r**4 + 30*r**2 -1)
            
    # Summed strictly left to right, Z1 .. Z36.
    phase = z1+z2+z3+z4+z5+z6+z7+z8+z9+z10+z11+z12+z13+z14+z15+z16+z17+z18+z19+z20+z21+z22+z23+z24+z25+z26+z27+z28+z29+z30+z31+z32+z33+z34+z35+z36
    return phase

def Mask(NA, wavelength, pixelSize, numPixels):
    """Build the binary pupil mask and the frequency sampling of the simulation.

    Parameters
    ----------
    NA : float
        Numerical aperture.
    wavelength : float
        Wavelength, in the same length unit as ``pixelSize`` (um in this file).
    pixelSize : float
        Image-plane sample spacing.
    numPixels : int
        Side length of the square simulation grid.

    Returns
    -------
    mask : ndarray of bool, shape (2*rpupil, 2*rpupil)
        True inside the unit disk of a [-1, 1] x [-1, 1] grid.
    rpupil : int
        Pupil radius in frequency-grid pixels: the NA cutoff
        2*pi*NA/wavelength divided by the grid step ``dk``, truncated.
    dk : float
        Angular spatial-frequency step 2*pi / (numPixels * pixelSize).
    """
    two_pi = 2 * np.pi
    # Full extent of the frequency grid (numPixels * dk), not the Nyquist limit.
    k_sampling = two_pi / pixelSize
    k_cutoff = two_pi * NA / wavelength  # coherent cutoff set by the NA
    delta_k = two_pi / (numPixels * pixelSize)
    radius_px = int(k_cutoff / k_sampling * numPixels)

    axis = np.linspace(-1, 1, radius_px * 2)
    gx, gy = np.meshgrid(axis, axis)
    # Keep sqrt(x^2 + y^2) <= 1 (not e.g. hypot or x^2 + y^2 <= 1) so the
    # boundary decision matches the `r > 1` test used when sampling the phase.
    inside = np.sqrt(gx**2 + gy**2) <= 1
    return inside, radius_px, delta_k

def Phase(z_ampl, rpupil, numPixels): # z_ampl: Z1 - Zn
    """Return the Zernike wavefront sampled on a 2*rpupil x 2*rpupil grid.

    The grid spans [-1, 1] on both axes, so the unit pupil disk touches the
    grid edges. Samples outside the unit disk (r > 1) are set to 0.

    Parameters
    ----------
    z_ampl : sequence of float
        Zernike coefficients Z1..Z36, passed to ``zernike_z1toz36``.
    rpupil : int
        Pupil radius in pixels; the grid has 2*rpupil samples per side.
    numPixels : int
        Not used here; kept so existing call sites keep working.

    Returns
    -------
    ndarray of shape (2*rpupil, 2*rpupil)
        Wavefront in the units of ``z_ampl`` (waves elsewhere in this file).
    """
    # Same grid and radius formula as the pupil mask built in Mask(), so the
    # `r > 1` cut below is exactly the complement of that mask.
    x, y = np.meshgrid(np.linspace(-1, 1, rpupil*2), np.linspace(-1, 1, rpupil*2))
    r = np.sqrt(x**2 + y**2)
    theta = np.arctan2(y, x)
    phase = zernike_z1toz36(z_ampl, r, theta)
    phase[r > 1] = 0
    return phase

def Complex_Pupil(mask, phase, rpupil, numPixels):
    """Embed the complex pupil function in a zero-filled numPixels x numPixels grid.

    The pupil field is ``mask * exp(i*2*pi*phase)``, with ``phase`` in waves.
    It occupies rows/columns numPixels//2 - rpupil + 1 .. numPixels//2 + rpupil,
    so its geometric centre lies half a pixel after index numPixels//2, the
    sample that ``ifftshift`` moves to the origin. By the Fourier shift
    theorem, a half-pixel pupil offset only multiplies the image-plane field
    by a linear phase, so |PSF|^2 is unaffected.

    Parameters
    ----------
    mask : ndarray of bool, shape (2*rpupil, 2*rpupil)
        Pupil aperture.
    phase : ndarray of float, shape (2*rpupil, 2*rpupil)
        Wavefront in waves.
    rpupil : int
        Pupil radius in pixels.
    numPixels : int
        Side length of the output grid.

    Returns
    -------
    ndarray of complex128, shape (numPixels, numPixels)

    Raises
    ------
    ValueError
        If a 2*rpupil wide pupil does not fit in the grid at that position.
        numpy also raises ValueError if ``mask``/``phase`` cannot be
        broadcast into the placement window.
    """
    pupil = mask * np.exp(1j*2*np.pi*phase)
    start = numPixels//2 - rpupil + 1
    stop = numPixels//2 + rpupil + 1
    if start < 0 or stop > numPixels:
        raise ValueError(
            'pupil of diameter {} px does not fit in a {} px grid'.format(2*rpupil, numPixels))
    pupil_center = np.zeros((numPixels, numPixels), dtype=complex)
    pupil_center[start:stop, start:stop] = pupil
    return pupil_center

def PSF(pupil, dk):
    """Return the intensity point-spread function of a centred complex pupil.

    The pupil is moved so its centre sample sits at index 0, Fourier
    transformed, re-centred with fftshift and scaled by the frequency-cell
    area dk**2. The squared modulus of that field is returned.
    """
    cell_area = dk**2
    field = fftshift(fft2(ifftshift(pupil)))
    field = field * cell_area
    intensity = np.abs(field)**2  # equivalent to |field * conj(field)|
    return intensity

def Display(x, title, color='jet'):
    """Show the 2-D array ``x`` as an image with a colour bar and a title.

    ``color`` is the colormap passed to imshow. Nearest-neighbour
    interpolation is used and the axes are hidden.
    """
    plt.imshow(x, cmap=color, interpolation='nearest')
    plt.colorbar()
    plt.title(title)
    plt.axis('off')
    plt.show()

def Display3D(x, y, z, title, zlim, cmap='jet'):
    """Plot z(x, y) as a 3-D surface plus a filled contour map at the top of the z range.

    ``zlim`` is a (low, high) pair used for the surface colour limits and the
    z-axis range; the contour map is drawn in the plane z = high.
    """
    z_lo, z_hi = zlim[0], zlim[1]
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(x, y, z, vmin=z_lo, vmax=z_hi, cmap=cmap)
    # Contour map projected onto the plane z = z_hi; the handle is not needed.
    ax.contourf(x, y, z, cmap=cmap, offset=z_hi)
    ax.set_zlim(z_lo, z_hi)
    fig.colorbar(surf, shrink=0.5)
    plt.title(title)
    plt.show()

def calculatePSF(mask, rpupil, numPixels, dk, z_ampl, window):
    """Simulate the PSF for the Zernike coefficients ``z_ampl`` and crop it.

    Parameters
    ----------
    mask, rpupil, dk
        Pupil mask, pupil radius (pixels) and frequency step from ``Mask``.
    numPixels : int
        Side length of the simulation grid.
    z_ampl : sequence of float
        Zernike coefficients Z1..Z36 (waves).
    window : int
        Side length of the returned PSF crop.

    Returns
    -------
    psf : ndarray of shape (window, window)
        PSF intensity cropped around index numPixels//2, which is where
        fftshift places the on-axis (zero-frequency) sample.
    phase : ndarray of shape (2*rpupil, 2*rpupil)
        The sampled wavefront.

    Raises
    ------
    ValueError
        If the crop does not fit inside the simulation grid.
    """
    phase = Phase(z_ampl, rpupil, numPixels)
    pupil = Complex_Pupil(mask, phase, rpupil, numPixels)
    psf = PSF(pupil, dk)
    # `start + window` keeps the crop exactly `window` wide for odd windows
    # too (centre + window//2 loses one pixel); even windows are unchanged.
    start = numPixels//2 - window//2
    stop = start + window
    if start < 0 or stop > numPixels:
        raise ValueError(
            'window of {} px does not fit in a {} px grid'.format(window, numPixels))
    psf = psf[start:stop, start:stop]
    return psf, phase

def gen_randZernike(Nsamples, Nz, sigma, rng=None):
    """
    Return random Zernike coefficient vectors Z_2 .. Z_Nz with a fixed total RMS.

    Each row is drawn from an isotropic Gaussian and rescaled to unit length.
    Its direction is therefore uniform on the sphere, and the root-sum-square
    (RSS) of its coefficients equals ``sigma`` (up to rounding). With
    RMS-normalised Zernike terms, that RSS is the total RMS wavefront error.
    ``sigma`` is therefore NOT a per-coefficient standard deviation.

    Parameters
    ----------
    Nsamples : int
        Number of coefficient vectors (rows).
    Nz : int
        Highest Zernike index; each row has Nz - 1 entries (Z_2 .. Z_Nz).
    sigma : float
        RSS (total RMS WFE) of every returned row.
    rng : numpy.random.Generator, optional
        Source of randomness. If None (default), the global ``np.random``
        state is used via ``np.random.randn``.

    Returns
    -------
    ndarray of shape (Nsamples, Nz - 1)
    """
    if rng is None:
        Za = np.random.randn(Nsamples, Nz-1)
    else:
        Za = rng.standard_normal((Nsamples, Nz-1))
    # keepdims lets each row be divided by its own norm without transposing.
    row_norm = np.sqrt(np.sum(Za**2, axis=-1, keepdims=True))
    return Za / row_norm * sigma

"""
    Rayleigh (quarter-wave) criterion: if the peak-to-valley (PV) wavefront error does not exceed
    0.25 wave (roughly 0.07 wave RMS), image quality is not noticeably degraded.
    The total RMS WFE is the root-sum-square (RSS) of the Zernike coefficients Zi in the RMS-normalized convention.
"""

#%% read training samples
# Regression data: `img` rows are the features and `wfe` rows the targets
# passed to regr.fit() in the training cell (one sample per row).
df1 = pd.read_csv('WFE500.csv', header=None)
df2 = pd.read_csv('TrainingImgSet500.csv', header=None)
wfe, img = df1.to_numpy(), df2.to_numpy()

#%% setup
sample = 100  # number of random test cases
WFE_max = 0.07  # upper bound of the total RMS WFE drawn per test case (waves)
wavelength = 0.55  # um
NA = 0.65
pixelSize = 0.1*wavelength/NA  # um
numPixels = 1025  # side length of the simulation grid
window = 100  # side length of the stored / plotted PSF crop
n = 36  # Zernike terms Z1..Zn; Z2..Zn are randomised in the test loop

# Aberration-free reference PSFs through focus. Their peak values (`norm`)
# normalise every simulated PSF at the same defocus.
focus = np.arange(-500, 501, 50)/1000  # defocus (Z4) in waves; alternative: np.arange(-50, 51, 5)/1000
nf = len(focus)
mask, rpupil, dk = Mask(NA, wavelength, pixelSize, numPixels)
# Display(mask, 'Pupil Mask')
z_ref = np.zeros(n)
psf_ref = np.zeros((nf, window**2))
norm = np.zeros(nf)
for f, defocus in enumerate(focus):
    z_ref[3] = defocus
    psf_ref_tmp, _ = calculatePSF(mask, rpupil, numPixels, dk, z_ref, window)
    norm[f] = psf_ref_tmp.max()
    psf_ref[f, :] = psf_ref_tmp.ravel() / norm[f]

#%% model training
# Fit a linear map from flattened through-focus PSF stacks (rows of `img`)
# to Zernike coefficient vectors (rows of `wfe`), and time the fit.
# perf_counter is a monotonic high-resolution clock intended for measuring
# intervals; time.time() can jump when the system clock is adjusted.
start_time = time.perf_counter()

regr = linear_model.LinearRegression()
regr.fit(img, wfe)

end_time = time.perf_counter()

execution_time = end_time - start_time
print('Training time = {:.2f} s'.format(execution_time))

#%% testing and plot results
# Monte-Carlo test: for each case draw a random aberration, simulate its
# through-focus PSF stack (normalised like the reference stack), predict the
# Zernike coefficients with the trained model and record the errors.
fitErr = np.zeros((sample, 2))  # per case: [max |error|, RSS error] in mWaves
for i in tqdm(range(sample)):
    z = np.zeros(n)
    WFE_std = WFE_max * np.random.rand()  # total RMS WFE, uniform in [0, WFE_max)
    # Random Z2..Zn whose RSS equals WFE_std; Z1 (piston) stays 0.
    z[1:] = gen_randZernike(1, n, WFE_std).flatten()
    # z = np.array([ 0.0, -0.00578118,  0.00812718, -0.00950725, -0.01115642,
    #         0.01136958,  0.00303714,  0.00244480, -0.00378351,  0.00481280,
    #         0.00422464, -0.00746966,  0.01109335,  0.0146784 , -0.00215813,
    #         0.00296778,  0.00395313, -0.00581871,  0.01605428, -0.00152662,
    #         0.00970721, -0.00821967,  0.01219332, -0.01100143, -0.00757982,
    #         0.00223501, -0.00265519,  0.01005227, -0.00909843, -0.0084124 ,
    #        -0.01198369,  0.00626413,  0.0105875 ,  0.01390956,  0.00521162,
    #        -0.01756241])
    # _, phase = calculatePSF(mask,rpupil,numPixels,dk,z,window)
    # Display(phase, 'WFE')

    psf = np.zeros((nf, window**2))
    for f in range(nf):
        # Defocus a copy: adding and then subtracting focus[f] in place does
        # not always restore z[3] bit-for-bit, and z is the ground truth.
        z_focus = z.copy()
        z_focus[3] = z[3] + focus[f]
        psf_tmp, _ = calculatePSF(mask, rpupil, numPixels, dk, z_focus, window)
        psf[f, :] = psf_tmp.flatten()/norm[f]

    # x, y = np.meshgrid(np.arange(window),np.arange(window))
    # Display3D(x, y, psf_ref[nf//2,:].reshape(window,window), 'PSF (aberration free)', zlim=[0.0,1.0], cmap='magma')
    # Display3D(x, y, psf[nf//2,:].reshape(window,window), 'PSF (with aberration)', zlim=[0.0,1.0], cmap='magma')

    z_predicted = regr.predict([psf.flatten()]).flatten()
    err = z_predicted - z
    max_err = np.amax(np.abs(err))
    rss_err = np.sqrt(np.sum(err**2))
    # tqdm.write prints without breaking the progress bar.
    tqdm.write('Max error = {:7.5f} Waves, RSS error = {:7.5f} Waves'.format(max_err, rss_err))
    fitErr[i, :] = (max_err*1000, rss_err*1000)

RSSErr_avg = np.round(np.mean(fitErr[:, 1]), 1)
RSSErr_std = np.round(np.std(fitErr[:, 1]), 1)

# plot error distribution (RSS error per test case, mWaves)
# The title names the model actually trained above: sklearn's LinearRegression.
plt.figure(figsize=(8, 6))
plt.hist(fitErr[:, 1], bins=30, edgecolor='k', alpha=0.7)
plt.title('Linear Regression Error Distribution (Ave = '+str(RSSErr_avg)+', STD = '+str(RSSErr_std)+')')
plt.xlabel('RSS Error (mWaves)')
plt.ylabel('Counts')
plt.show()

# plot one example: true vs predicted coefficients of the last test case
barWidth = 0.25
br1 = np.arange(len(z_predicted))
br2 = br1 + barWidth
plt.bar(br1, z*1000, color='r', width=barWidth, label='Truth')
plt.bar(br2, z_predicted*1000, color='g', width=barWidth, label='Predicted')
# Centre a tick on each pair of bars; array position k holds Z(k+1).
plt.xticks(br1 + barWidth/2, [str(k + 1) for k in br1], fontsize=7)
plt.legend()
plt.xlabel('Zn')
plt.ylabel('WFE (mWave)')
plt.title('Test Example')
plt.show()

# Through-focus PSF stack of the predicted aberration, normalised like the
# true stack, for visual comparison.
psf_predicted = np.zeros((nf, window**2))
for f in range(nf):
    # Defocus a copy so z_predicted is left untouched; in-place += / -=
    # can leave rounding residue in z_predicted[3].
    z_pred_focus = z_predicted.copy()
    z_pred_focus[3] = z_predicted[3] + focus[f]
    psf_predicted_tmp, _ = calculatePSF(mask, rpupil, numPixels, dk, z_pred_focus, window)
    psf_predicted[f, :] = psf_predicted_tmp.flatten()/norm[f]
    

# True, predicted and difference PSF stacks at every other focus position.
# range(0, nf, 2) never indexes past the last focus position (nf//2 + 1
# panels would for an even nf). squeeze=False keeps `axes` iterable even
# when only one panel is drawn.
shown = range(0, nf, 2)
for title, stack in (('True PSF', psf),
                     ('Predicted PSF', psf_predicted),
                     ('Difference', psf_predicted - psf)):
    fig, axes = plt.subplots(1, len(shown), figsize=(2*len(shown), 2.5), squeeze=False)
    for ax, k in zip(axes.flat, shown):
        im = ax.imshow(stack[k, :].reshape(window, window), cmap='magma_r')
        ax.set_title(str(int(round(focus[k]*1000))))  # defocus in mWaves
        ax.axis('off')
    cbar = fig.colorbar(im, ax=list(axes.flat), orientation='vertical', fraction=0.02, pad=0.04)
    fig.suptitle(title, fontsize=16)
    plt.show()